#!/usr/bin/env python3
"""
Unified Scheduler with Prefill-Decoding Optimization
===================================================

Integrated scheduling and resource management with intelligent prefill-decoding support:
1. Multi-model scheduler
2. Distributed resource management
3. Task scheduling and orchestration
4. Intelligent prefill-decoding scheduling with TTFT/TBT optimization
5. Asynchronous architecture support
6. Reward-based slot management

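Prefill-decoding optimization targets the schedule

    S* = argmin_S max(TTFT(S), TBT(S))   subject to   Σ_i |KV_i| ≤ M_total

i.e. minimize the worse of time-to-first-token (TTFT) and time-between-tokens
(TBT) for schedule S, while the summed KV-cache size |KV_i| of the scheduled
requests i stays within the total memory budget M_total.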
"""

import asyncio
import logging
import time
import threading
import queue
import math
from typing import Any, Dict, List, Optional, Tuple, Callable, Union
from dataclasses import dataclass, field
from enum import Enum
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger(__name__)

try:
    import aiohttp
    HAS_AIOHTTP = True
except ImportError:
    HAS_AIOHTTP = False


class ModelRole(Enum):
    """Role a model is registered with in the scheduler.

    The string value of each member is its serialized name.
    """

    GENERALIST = "generalist"
    SPECIALIST = "specialist"
    COORDINATOR = "coordinator"
    COMPETITOR = "competitor"
    COLLABORATOR = "collaborator"


class InteractionType(Enum):
    """Mode in which the models assigned to a task interact.

    The string values double as keys in ``ModelProfile.interaction_preferences``.
    """

    COOPERATION = "cooperation"
    COMPETITION = "competition"
    NEUTRAL = "neutral"


class TaskPriority(Enum):
    """Priority attached to tasks and slots.

    Members are declared from most to least urgent.
    """

    CRITICAL = "critical"
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
    BACKGROUND = "background"


class TaskType(Enum):
    """Phase a prefill-decoding task belongs to."""

    PREFILL = "prefill"  # initial prompt processing
    DECODE = "decode"    # token-by-token generation
    MIXED = "mixed"      # both prefill and decode work


class ResourceConstraintType(Enum):
    """Kinds of resources a scheduling constraint can refer to."""

    MEMORY = "memory"
    GPU = "gpu"
    CPU = "cpu"
    BANDWIDTH = "bandwidth"


class SlotState(Enum):
    """Lifecycle state of a slot tracked by the slot manager."""

    IDLE = "idle"            # created, not yet running
    RUNNING = "running"
    BLOCKED = "blocked"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class ModelProfile:
    """Description and runtime bookkeeping for one registered model.

    If ``interaction_preferences`` is empty (the default), it is replaced by
    a fresh mapping weighting cooperation 0.5, competition 0.3 and neutral
    0.2, keyed by the ``InteractionType`` string values.
    """

    model_id: str
    gpu_id: int
    port: int
    url: str
    role: ModelRole
    capabilities: Dict[str, float]  # capability name -> score
    performance_history: List[float] = field(default_factory=list)
    resource_usage: Dict[str, float] = field(default_factory=dict)
    interaction_preferences: Dict[str, float] = field(default_factory=dict)
    is_healthy: bool = True
    last_health_check: float = 0.0

    def __post_init__(self):
        # A non-empty mapping supplied by the caller is kept as-is.
        if self.interaction_preferences:
            return
        default_weights = (
            (InteractionType.COOPERATION, 0.5),
            (InteractionType.COMPETITION, 0.3),
            (InteractionType.NEUTRAL, 0.2),
        )
        self.interaction_preferences = {
            kind.value: weight for kind, weight in default_weights
        }


@dataclass
class TaskDefinition:
    """User-facing description of a task submitted to the scheduler.

    ``complexity`` scales the resources requested and the simulated runtime
    of subtasks. ``reward_structure`` may carry a ``'base_reward'`` entry,
    which becomes the slot's initial reward.
    """

    task_id: str
    task_type: str
    complexity: float
    required_capabilities: List[str]
    collaboration_required: bool = False  # True forces cooperative execution
    competition_allowed: bool = True      # False forces neutral execution
    reward_structure: Dict[str, float] = field(default_factory=dict)
    deadline: Optional[float] = None
    priority: TaskPriority = TaskPriority.MEDIUM


@dataclass
class PrefillDecodingTask:
    """Task for prefill-decoding scheduling.

    ``created_time`` defaults to 0.0, which ``__post_init__`` treats as
    "unset" and replaces with the current ``time.time()``. Callers may
    therefore omit it, or pass an explicit timestamp to keep it.
    """

    task_id: str
    task_type: TaskType
    priority: TaskPriority
    input_length: int
    max_output_length: int
    estimated_tokens: int
    memory_requirement: int  # Memory requirement in bytes
    deadline: float  # Deadline timestamp
    # 0.0 is the "not provided" sentinel handled in __post_init__.
    created_time: float = 0.0
    user_id: Optional[str] = None

    def __post_init__(self):
        if self.created_time == 0.0:
            self.created_time = time.time()


@dataclass
class ResourceConstraints:
    """Resource limits used for prefill-decoding scheduling.

    Memory sizes appear to be in bytes (the scheduler passes values such as
    ``16 * 1024**3``).
    """

    total_memory: int          # total memory available for scheduling
    max_concurrent_tasks: int  # upper bound on simultaneously running tasks
    gpu_memory: int            # GPU memory available
    cpu_memory: int            # CPU memory available
    max_sequence_length: int   # longest sequence allowed


@dataclass
class PerformanceMetrics:
    """Performance targets for prefill-decoding optimization."""

    # Latency targets, in seconds.
    ttft_target: float = 0.1   # time to first token
    tbt_target: float = 0.05   # time between tokens
    # Efficiency / throughput targets.
    memory_efficiency_target: float = 0.8
    throughput_target: float = 100.0  # tokens per second


@dataclass
class SlotInfo:
    """Bookkeeping record for one slot tracked by the slot manager.

    Timestamps are ``time.time()`` values. ``started_at`` and
    ``completed_at`` remain None until the slot starts or finishes.
    """

    slot_id: str
    priority: TaskPriority
    state: SlotState
    reward: float
    created_at: float
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    execution_time: float = 0.0  # seconds; filled in on completion
    resource_usage: Dict[str, float] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class InteractionResult:
    """Outcome record of one multi-model interaction on a task."""

    interaction_id: str
    model_ids: List[str]                     # participating models
    interaction_type: InteractionType
    success: bool
    performance_metrics: Dict[str, float]
    resource_consumption: Dict[str, float]
    timestamp: float                         # time.time() when recorded


class ResourceManager:
    """Unified resource manager for per-GPU compute and memory budgets.

    Every GPU starts with 100.0 units of ``compute`` and 100.0 units of
    ``memory``. Requests are all-or-nothing: either every tracked resource in
    the request fits and all of them are allocated, or nothing is allocated
    and each shortfall is recorded as a competition entry. Public methods
    serialize on ``self.lock``.
    """

    def __init__(self, num_gpus: int = 8):
        self.num_gpus = num_gpus
        # Capacity per resource type, indexed by GPU id.
        self.gpu_resources = {
            'compute': [100.0] * num_gpus,
            'memory': [100.0] * num_gpus,
        }
        # Amount currently granted per resource type, indexed by GPU id.
        self.allocated_resources = {
            'compute': [0.0] * num_gpus,
            'memory': [0.0] * num_gpus,
        }
        # Unsatisfied requests per resource type and GPU, kept sorted by
        # priority (highest first) and then arrival time (oldest first).
        self.resource_queues = {
            'compute': [[] for _ in range(num_gpus)],
            'memory': [[] for _ in range(num_gpus)],
        }
        self.competition_history = []
        self.lock = threading.Lock()

    def _is_valid_gpu(self, gpu_id: int) -> bool:
        """Return True if ``gpu_id`` names an existing GPU.

        Negative ids are rejected explicitly; otherwise Python's negative
        list indexing would silently map them onto the last GPUs.
        """
        return 0 <= gpu_id < self.num_gpus

    def request_resources(self, gpu_id: int, resources: Dict[str, float], priority: float = 1.0) -> bool:
        """Atomically request GPU resources.

        Args:
            gpu_id: Target GPU index.
            resources: Mapping of resource type (``'compute'`` / ``'memory'``)
                to amount. Resource types this manager does not track are
                ignored.
            priority: Orders the competition queue if the request cannot be
                satisfied (higher goes first).

        Returns:
            True if every tracked resource fit and was allocated (also True
            when no tracked resource is named). False if ``gpu_id`` is
            invalid or any resource was short; in that case nothing is
            allocated and each short resource is queued for competition.
        """
        with self.lock:
            if not self._is_valid_gpu(gpu_id):
                return False

            tracked = {
                rtype: amount
                for rtype, amount in resources.items()
                if rtype in self.gpu_resources
            }

            # Check every resource before touching any allocation so a
            # request is never granted only partially.
            shortfalls = [
                (rtype, amount)
                for rtype, amount in tracked.items()
                if self.gpu_resources[rtype][gpu_id] - self.allocated_resources[rtype][gpu_id] < amount
            ]
            if shortfalls:
                for rtype, amount in shortfalls:
                    self._enter_competition(gpu_id, rtype, amount, priority)
                return False

            for rtype, amount in tracked.items():
                self.allocated_resources[rtype][gpu_id] += amount
            return True

    def _enter_competition(self, gpu_id: int, resource_type: str, amount: float, priority: float):
        """Queue an unsatisfied request and record it in the history.

        Must be called with ``self.lock`` held.
        """
        now = time.time()
        waiters = self.resource_queues[resource_type][gpu_id]
        waiters.append({
            'gpu_id': gpu_id,
            'amount': amount,
            'priority': priority,
            'timestamp': now
        })

        # Highest priority first; among equal priorities, oldest first.
        waiters.sort(
            key=lambda x: (x['priority'], -x['timestamp']),
            reverse=True
        )

        self.competition_history.append({
            'gpu_id': gpu_id,
            'resource_type': resource_type,
            'amount': amount,
            'priority': priority,
            'timestamp': now
        })

    def release_resources(self, gpu_id: int, resources: Dict[str, float]):
        """Return previously allocated resources; allocations never go below 0."""
        with self.lock:
            if not self._is_valid_gpu(gpu_id):
                return

            for resource_type, amount in resources.items():
                if resource_type in self.allocated_resources:
                    current = self.allocated_resources[resource_type][gpu_id]
                    self.allocated_resources[resource_type][gpu_id] = max(0.0, current - amount)

    def get_resource_stats(self) -> Dict[str, Any]:
        """Return the competition count and per-GPU allocated/capacity ratios."""
        with self.lock:
            return {
                'total_competitions': len(self.competition_history),
                'gpu_allocation_rates': {
                    f'gpu_{gpu_id}': {
                        'compute_rate': self.allocated_resources['compute'][gpu_id] / self.gpu_resources['compute'][gpu_id],
                        'memory_rate': self.allocated_resources['memory'][gpu_id] / self.gpu_resources['memory'][gpu_id]
                    }
                    for gpu_id in range(self.num_gpus)
                }
            }


class SlotManager:
    """Reward-based slot manager.

    At most ``max_slots`` slots run at once. Slots created beyond that wait
    and are promoted (highest priority first, then highest reward) whenever
    a running slot completes. Public methods serialize on ``self.lock``.
    """

    def __init__(self, max_slots: int = 10):
        self.max_slots = max_slots
        self.active_slots = {}     # slot_id -> SlotInfo in RUNNING state
        self.waiting_slots = {}    # slot_id -> SlotInfo waiting for capacity
        self.completed_slots = {}  # slot_id -> SlotInfo that finished
        self.slot_rewards = {}     # slot_id -> latest reward
        self.lock = threading.Lock()

        # Statistics
        self.stats = {
            'total_slots': 0,
            'active_slots': 0,
            'completed_slots': 0,
            'failed_slots': 0
        }

    def create_slot(self, priority: TaskPriority, reward: float,
                    resource_usage: Optional[Dict[str, float]] = None,
                    metadata: Optional[Dict[str, Any]] = None) -> str:
        """Create a slot, starting it immediately if capacity allows.

        Returns:
            The new slot's id.
        """
        with self.lock:
            # Suffix with the running creation count so ids are unique. The
            # number of active slots is not unique (it stays at max_slots
            # while the manager is full), so slots created within the same
            # second would otherwise get identical ids and overwrite each other.
            slot_id = f"slot_{int(time.time())}_{self.stats['total_slots']}"

            slot = SlotInfo(
                slot_id=slot_id,
                priority=priority,
                state=SlotState.IDLE,
                reward=reward,
                created_at=time.time(),
                resource_usage=resource_usage or {},
                metadata=metadata or {}
            )

            if len(self.active_slots) < self.max_slots:
                self.active_slots[slot_id] = slot
                slot.state = SlotState.RUNNING
                slot.started_at = time.time()
                self.stats['active_slots'] += 1
            else:
                self.waiting_slots[slot_id] = slot

            self.slot_rewards[slot_id] = reward
            self.stats['total_slots'] += 1

            logger.info(f"Created slot {slot_id}, priority: {priority.value}, reward: {reward}")
            return slot_id

    def complete_slot(self, slot_id: str, final_reward: Optional[float] = None) -> bool:
        """Mark a running or waiting slot as completed.

        A slot that is still waiting can be completed too (its work may have
        finished or failed before it was ever promoted). Without this it
        would later be promoted to running and never released.

        Args:
            slot_id: Slot to complete.
            final_reward: If given, replaces the slot's reward.

        Returns:
            True if the slot was running or waiting; False otherwise.
        """
        with self.lock:
            slot = self.active_slots.pop(slot_id, None)
            was_active = slot is not None
            if slot is None:
                slot = self.waiting_slots.pop(slot_id, None)
            if slot is None:
                return False

            slot.state = SlotState.COMPLETED
            slot.completed_at = time.time()
            # A never-started slot measures from its creation time.
            slot.execution_time = slot.completed_at - (slot.started_at or slot.created_at)

            if final_reward is not None:
                slot.reward = final_reward
                self.slot_rewards[slot_id] = final_reward

            self.completed_slots[slot_id] = slot

            if was_active:
                self.stats['active_slots'] -= 1
            self.stats['completed_slots'] += 1

            # Fill any freed capacity from the waiting set.
            self._start_waiting_slots()

            logger.info(f"Completed slot {slot_id}, final reward: {slot.reward}")
            return True

    def _start_waiting_slots(self):
        """Promote waiting slots while capacity remains. Caller holds the lock."""
        while len(self.active_slots) < self.max_slots and self.waiting_slots:
            # Select slot with highest priority, then highest reward
            best_slot_id = max(
                self.waiting_slots.keys(),
                key=lambda sid: (
                    self._get_priority_value(self.waiting_slots[sid].priority),
                    self.waiting_slots[sid].reward
                )
            )

            slot = self.waiting_slots[best_slot_id]
            slot.state = SlotState.RUNNING
            slot.started_at = time.time()

            self.active_slots[best_slot_id] = slot
            del self.waiting_slots[best_slot_id]

            self.stats['active_slots'] += 1

    def _get_priority_value(self, priority: TaskPriority) -> float:
        """Map a priority to a numeric weight (unknown values get 0.5)."""
        priority_values = {
            TaskPriority.CRITICAL: 1.0,
            TaskPriority.HIGH: 0.8,
            TaskPriority.MEDIUM: 0.6,
            TaskPriority.LOW: 0.4,
            TaskPriority.BACKGROUND: 0.2
        }
        return priority_values.get(priority, 0.5)

    def update_slot_reward(self, slot_id: str, new_reward: float) -> bool:
        """Update the reward of a running slot; return False if it isn't running."""
        with self.lock:
            if slot_id in self.active_slots:
                self.active_slots[slot_id].reward = new_reward
                self.slot_rewards[slot_id] = new_reward
                return True
            return False

    def get_slot_stats(self) -> Dict[str, Any]:
        """Return counters plus waiting count and average time/reward."""
        with self.lock:
            return {
                **self.stats,
                'waiting_slots': len(self.waiting_slots),
                'average_execution_time': self._calculate_avg_execution_time(),
                'average_reward': self._calculate_avg_reward()
            }

    def _calculate_avg_execution_time(self) -> float:
        """Average execution time over completed slots (0.0 if none)."""
        completed = list(self.completed_slots.values())
        if not completed:
            return 0.0
        return sum(slot.execution_time for slot in completed) / len(completed)

    def _calculate_avg_reward(self) -> float:
        """Average of the latest reward of every slot ever created (0.0 if none)."""
        if not self.slot_rewards:
            return 0.0
        return sum(self.slot_rewards.values()) / len(self.slot_rewards)


class VLLMClient:
    """Asynchronous client for a vLLM chat-completions endpoint.

    Intended use is ``async with VLLMClient(...) as client:``. Entering opens
    an aiohttp session (raising ImportError if aiohttp is not installed), and
    exiting closes it.
    """

    def __init__(self, endpoint: str, model_name: str, max_retries: int = 3):
        self.endpoint = endpoint
        self.model_name = model_name
        self.max_retries = max_retries
        self.session = None

    async def __aenter__(self):
        """Open the HTTP session."""
        if not HAS_AIOHTTP:
            raise ImportError("aiohttp required: pip install aiohttp")

        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session and detach it from the client."""
        if self.session is not None:
            await self.session.close()
            # Drop the closed session so later calls fail fast instead of
            # retrying against a closed connection pool.
            self.session = None

    async def generate(self, prompt: str) -> str:
        """Generate a completion for ``prompt``.

        Sends a single-message chat request (max_tokens=512, temperature=0.7)
        and returns ``choices[0].message.content``. Up to ``max_retries``
        attempts are made, with a 1 s pause between attempts after both
        non-200 responses and exceptions.

        If aiohttp is not installed, a mock response is returned after a
        short asynchronous delay.

        Raises:
            RuntimeError: if no session is open (client not entered with
                ``async with``), or every attempt returned a non-200 status.
            Exception: the underlying error when the final attempt raises.
        """
        if not HAS_AIOHTTP:
            # aiohttp unavailable: simulate a response with a short async delay.
            await asyncio.sleep(0.1)
            return f"Mock response for: {prompt[:50]}..."

        if self.session is None:
            raise RuntimeError(
                "VLLMClient has no open session; use 'async with VLLMClient(...)'"
            )

        payload = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 512,
            "temperature": 0.7
        }

        for attempt in range(1, self.max_retries + 1):
            is_last = attempt >= self.max_retries
            try:
                async with self.session.post(self.endpoint, json=payload) as response:
                    if response.status == 200:
                        result = await response.json()
                        return result["choices"][0]["message"]["content"]
                    logger.warning(f"HTTP {response.status}: {await response.text()}")
            except Exception as e:
                logger.error(f"Attempt {attempt} failed: {e}")
                if is_last:
                    raise
            # Back off before retrying, for HTTP errors and exceptions alike.
            if not is_last:
                await asyncio.sleep(1)

        raise RuntimeError("All retry attempts failed")


class CapabilityAnalyzer:
    """Tracks how model capability scores evolve and how much models differ."""

    def __init__(self):
        # model_id -> list of {'timestamp': ..., 'capabilities': {...}} snapshots
        self.capability_evolution = {}
        self.specialization_trends = {}

    def analyze_model_capabilities(self, model_profile: ModelProfile,
                                   task_results: List[Dict[str, Any]]) -> Dict[str, float]:
        """Return capability scores blended with observations from successful results.

        For every result whose ``'success'`` is truthy, each entry of its
        ``'capability_scores'`` naming a capability the model already has is
        blended in as ``0.8 * old + 0.2 * observed``. Unknown names are
        ignored. The profile itself is not modified.

        A snapshot of the new scores is appended to
        ``capability_evolution[model_id]``. When ``task_results`` is empty,
        an unchanged copy is returned and nothing is recorded.
        """
        scores = model_profile.capabilities.copy()
        if not task_results:
            return scores

        for outcome in task_results:
            if not outcome.get('success', False):
                continue
            for name, observed in outcome.get('capability_scores', {}).items():
                if name in scores:
                    scores[name] = scores[name] * 0.8 + observed * 0.2

        snapshots = self.capability_evolution.setdefault(model_profile.model_id, [])
        snapshots.append({
            'timestamp': time.time(),
            'capabilities': scores.copy()
        })
        return scores

    def detect_functional_differentiation(self, model_profiles: List[ModelProfile]) -> Dict[str, Any]:
        """Measure how much the models' capability scores diverge.

        For each capability name seen on any model, the population standard
        deviation of its scores across models is computed; a model lacking
        that capability counts as 0.5. Despite the key name,
        ``capability_variances`` holds these standard deviations. Their mean
        is ``differentiation_level``; above 0.3 it is labelled
        'high_differentiation'. Fewer than two models yields level 0.0 and
        'insufficient_models'.
        """
        if len(model_profiles) < 2:
            return {'differentiation_level': 0.0, 'analysis': 'insufficient_models'}

        capability_maps = [profile.capabilities for profile in model_profiles]
        names = set()
        for mapping in capability_maps:
            names.update(mapping.keys())

        count = len(capability_maps)
        spreads = {}
        for name in names:
            column = [mapping.get(name, 0.5) for mapping in capability_maps]
            centre = sum(column) / count
            spreads[name] = (sum((x - centre) ** 2 for x in column) / count) ** 0.5

        level = sum(spreads.values()) / len(spreads) if spreads else 0.0

        return {
            'differentiation_level': level,
            'capability_variances': spreads,
            'analysis': 'high_differentiation' if level > 0.3 else 'low_differentiation'
        }


class InteractionOrchestrator:
    """Chooses how models interact on a task and builds the execution plan."""

    def __init__(self):
        self.interaction_history = []

    def determine_interaction_type(self, model_profiles: List[ModelProfile],
                                   task: TaskDefinition) -> InteractionType:
        """Pick the interaction type for ``task``.

        Task flags take precedence: ``collaboration_required`` forces
        COOPERATION, and ``competition_allowed=False`` forces NEUTRAL.
        Otherwise the models' average preferences decide:
        - COOPERATION if the average cooperation preference is higher;
        - COMPETITION if the average competition preference exceeds 0.5;
        - NEUTRAL in all other cases.
        With no models there is nothing to average, so NEUTRAL is returned.
        """
        if task.collaboration_required:
            return InteractionType.COOPERATION

        if not task.competition_allowed:
            return InteractionType.NEUTRAL

        # Avoid dividing by zero when no models were supplied.
        if not model_profiles:
            return InteractionType.NEUTRAL

        count = len(model_profiles)

        # Decide based on model preferences
        cooperation_preference = sum(
            profile.interaction_preferences.get(InteractionType.COOPERATION.value, 0.5)
            for profile in model_profiles
        ) / count

        competition_preference = sum(
            profile.interaction_preferences.get(InteractionType.COMPETITION.value, 0.3)
            for profile in model_profiles
        ) / count

        if cooperation_preference > competition_preference:
            return InteractionType.COOPERATION
        elif competition_preference > 0.5:
            return InteractionType.COMPETITION
        else:
            return InteractionType.NEUTRAL

    def orchestrate_cooperation(self, model_profiles: List[ModelProfile],
                                task: TaskDefinition) -> Dict[str, Any]:
        """Build a cooperative plan assigning every subtask to some model.

        Subtasks are dealt round-robin in model order. Every subtask gets an
        owner even when there are fewer models than subtasks, and models then
        receive several subtasks each. With at least as many models as
        subtasks, model i receives subtask i and extra models get none.
        """
        subtasks = self._generate_subtasks(task)
        subtask_assignments: Dict[str, List[str]] = {}

        if model_profiles:
            for index, subtask in enumerate(subtasks):
                owner = model_profiles[index % len(model_profiles)].model_id
                subtask_assignments.setdefault(owner, []).append(subtask)

        return {
            'execution_mode': 'cooperative',
            'subtask_assignments': subtask_assignments,
            'coordination_strategy': 'sequential',
            'integration_points': ['result_aggregation']
        }

    def orchestrate_competition(self, model_profiles: List[ModelProfile],
                                task: TaskDefinition) -> Dict[str, Any]:
        """Build a competitive plan (static; arguments are not consulted)."""
        return {
            'execution_mode': 'competitive',
            'evaluation_criteria': ['response_quality', 'execution_time'],
            'winner_selection': 'highest_score'
        }

    def _generate_subtasks(self, task: TaskDefinition) -> List[str]:
        """Return 2-4 generic subtask names, one per required capability (clamped)."""
        base_subtasks = ['analysis', 'planning', 'execution', 'validation']
        return base_subtasks[:max(2, min(4, len(task.required_capabilities)))]


class UnifiedScheduler:
    """Unified Scheduler"""
    
    def __init__(self, base_port: int = 8001, num_gpus: int = 8, model_name: str = "qwen-2",
                 max_concurrent_tasks: int = 20, max_slots: int = 10):
        """Create the scheduler and its collaborating components.

        Args:
            base_port: Port of the model on GPU 0; GPU ``i`` uses ``base_port + i``.
            num_gpus: Number of GPUs managed by the resource manager.
            model_name: Model name recorded on the scheduler.
            max_concurrent_tasks: Size of the execution semaphore and the
                task-queue limit.
            max_slots: Number of concurrently running slots in the slot manager.
        """
        # --- configuration ---
        self.base_port = base_port
        self.num_gpus = num_gpus
        self.model_name = model_name
        self.max_concurrent_tasks = max_concurrent_tasks

        # --- collaborating components ---
        self.resource_manager = ResourceManager(num_gpus)
        self.slot_manager = SlotManager(max_slots)
        self.capability_analyzer = CapabilityAnalyzer()
        self.interaction_orchestrator = InteractionOrchestrator()

        # model_id -> ModelProfile
        self.model_profiles = {}

        # --- task bookkeeping ---
        self.active_tasks = {}
        self.task_history = []
        self.interaction_history = []

        counter_names = ('total_tasks', 'cooperation_tasks', 'competition_tasks', 'neutral_tasks')
        self.stats = {name: 0 for name in counter_names}

        # Bounds how many task executions run at once.
        self.semaphore = asyncio.Semaphore(max_concurrent_tasks)

        # --- prefill-decoding scheduling state ---
        self.prefill_tasks = defaultdict(list)  # priority -> [PrefillDecodingTask]
        self.active_prefill_tasks = {}
        self.completed_prefill_tasks = deque(maxlen=10000)

        gib = 1024 ** 3
        self.resource_constraints = ResourceConstraints(
            total_memory=16 * gib,
            max_concurrent_tasks=max_concurrent_tasks,
            gpu_memory=8 * gib,
            cpu_memory=32 * gib,
            max_sequence_length=32768,
        )

        self.performance_metrics = PerformanceMetrics()

        # Rolling windows of the most recent measurements.
        window = 1000
        self.ttft_history = deque(maxlen=window)
        self.tbt_history = deque(maxlen=window)
        self.throughput_history = deque(maxlen=window)

        logger.info(f"Unified Scheduler initialized: {num_gpus} GPUs")
    
    def register_model(self, model_id: str, gpu_id: int,
                       role: ModelRole = ModelRole.GENERALIST,
                       initial_capabilities: Optional[Dict[str, float]] = None) -> bool:
        """Register a model served on ``gpu_id``.

        The model's endpoint is taken to be
        ``http://localhost:<base_port + gpu_id>/v1``.

        Args:
            model_id: Unique model identifier.
            gpu_id: GPU index in ``[0, num_gpus)``.
            role: Role of the model.
            initial_capabilities: Capability scores. If omitted, reasoning,
                creativity, efficiency and accuracy all default to 0.5.

        Returns:
            True on success. False if ``model_id`` is already registered or
            ``gpu_id`` is out of range.
        """
        if model_id in self.model_profiles:
            logger.warning(f"Model {model_id} already exists")
            return False

        # Negative ids are rejected too: they would produce a port below
        # base_port and index GPUs from the end in the resource manager.
        if not 0 <= gpu_id < self.num_gpus:
            logger.error(f"GPU ID {gpu_id} out of range")
            return False

        if initial_capabilities is None:
            initial_capabilities = {
                'reasoning': 0.5,
                'creativity': 0.5,
                'efficiency': 0.5,
                'accuracy': 0.5
            }

        port = self.base_port + gpu_id
        profile = ModelProfile(
            model_id=model_id,
            gpu_id=gpu_id,
            port=port,
            url=f"http://localhost:{port}/v1",
            role=role,
            capabilities=initial_capabilities
        )

        self.model_profiles[model_id] = profile
        logger.info(f"Registered model {model_id} to GPU {gpu_id}")

        return True
    
    async def submit_task(self, task: TaskDefinition,
                          selected_models: Optional[List[str]] = None) -> str:
        """Plan and launch ``task`` in the background; return its task_id.

        Args:
            task: Task to run.
            selected_models: Model ids to involve. Defaults to all registered
                models; unregistered ids are dropped.

        Raises:
            RuntimeError: if ``max_concurrent_tasks`` tasks are still running.
            ValueError: if a task with the same task_id is still running, or
                none of the selected models is registered.
        """
        # Finished tasks stay in active_tasks (with their results) after
        # _execute_task marks them completed. Only entries still running
        # count against the limit; otherwise the queue would report "full"
        # forever once max_concurrent_tasks tasks had ever been submitted.
        running = sum(
            1 for info in self.active_tasks.values()
            if info.get('status') == 'running'
        )
        if running >= self.max_concurrent_tasks:
            raise RuntimeError("Task queue is full")

        # Reusing a running id would let the first execution's cleanup mark
        # the second one as completed and overwrite its record.
        existing = self.active_tasks.get(task.task_id)
        if existing is not None and existing.get('status') == 'running':
            raise ValueError(f"Task {task.task_id} is already running")

        # Select participating models
        if selected_models is None:
            selected_models = list(self.model_profiles.keys())

        valid_models = [mid for mid in selected_models if mid in self.model_profiles]
        if not valid_models:
            raise ValueError("No valid models available")

        # Determine interaction type
        model_profiles = [self.model_profiles[mid] for mid in valid_models]
        interaction_type = self.interaction_orchestrator.determine_interaction_type(model_profiles, task)

        # Create slot
        slot_id = self.slot_manager.create_slot(
            priority=task.priority,
            reward=task.reward_structure.get('base_reward', 1.0),
            metadata={'task_id': task.task_id, 'interaction_type': interaction_type.value}
        )

        # Create execution plan
        if interaction_type == InteractionType.COOPERATION:
            execution_plan = self.interaction_orchestrator.orchestrate_cooperation(model_profiles, task)
        elif interaction_type == InteractionType.COMPETITION:
            execution_plan = self.interaction_orchestrator.orchestrate_competition(model_profiles, task)
        else:
            execution_plan = {'execution_mode': 'parallel'}

        # Start task. No await occurs before the bookkeeping below, so the
        # record exists before the executor first runs.
        task_executor = asyncio.create_task(
            self._execute_task(task, valid_models, interaction_type, execution_plan, slot_id)
        )

        self.active_tasks[task.task_id] = {
            'task': task,
            'models': valid_models,
            'interaction_type': interaction_type,
            'executor': task_executor,
            'slot_id': slot_id,
            'start_time': time.time(),
            'status': 'running'
        }

        self.stats['total_tasks'] += 1
        if interaction_type == InteractionType.COOPERATION:
            self.stats['cooperation_tasks'] += 1
        elif interaction_type == InteractionType.COMPETITION:
            self.stats['competition_tasks'] += 1
        else:
            self.stats['neutral_tasks'] += 1

        logger.info(f"Submitted task {task.task_id}, interaction type: {interaction_type.value}")
        return task.task_id
    
    async def _execute_task(self, task: TaskDefinition, model_ids: List[str],
                            interaction_type: InteractionType, execution_plan: Dict[str, Any],
                            slot_id: str) -> Dict[str, Any]:
        """Run ``task`` under the concurrency semaphore and record the outcome.

        Dispatches on ``interaction_type``. On success:
        - completes the slot with the computed final reward;
        - records an InteractionResult;
        - updates the capabilities of the models involved.

        On failure, results become ``{'success': False, 'error': ...}`` and the
        slot is completed with reward 0.0. In every case the task's
        ``active_tasks`` entry is marked completed and a history entry is
        appended.

        Returns:
            The results dict.
        """
        async with self.semaphore:
            start_time = time.time()
            results = {}
            slot_closed = False

            try:
                if interaction_type == InteractionType.COOPERATION:
                    results = await self._execute_cooperation_task(task, model_ids, execution_plan)
                elif interaction_type == InteractionType.COMPETITION:
                    results = await self._execute_competition_task(task, model_ids, execution_plan)
                else:
                    results = await self._execute_neutral_task(task, model_ids, execution_plan)

                # Calculate final reward
                final_reward = self._calculate_final_reward(results, task)

                # Complete slot
                self.slot_manager.complete_slot(slot_id, final_reward)
                slot_closed = True

                # Record interaction result
                interaction_result = InteractionResult(
                    interaction_id=f"{task.task_id}_{interaction_type.value}",
                    model_ids=model_ids,
                    interaction_type=interaction_type,
                    success=True,
                    performance_metrics=results.get('performance_metrics', {}),
                    resource_consumption=results.get('resource_consumption', {}),
                    timestamp=time.time()
                )

                self.interaction_history.append(interaction_result)

                # Update model capabilities
                await self._update_model_capabilities(model_ids, results)

            except Exception as e:
                logger.exception(f"Task execution failed: {e}")
                results = {'success': False, 'error': str(e)}

            finally:
                # Release the slot on every exit path. This includes task
                # cancellation: asyncio.CancelledError is not an Exception
                # subclass (Python 3.8+), so the handler above never sees it,
                # and the slot would otherwise stay occupied forever.
                if not slot_closed:
                    self.slot_manager.complete_slot(slot_id, 0.0)

                # Update task status
                if task.task_id in self.active_tasks:
                    self.active_tasks[task.task_id]['status'] = 'completed'
                    self.active_tasks[task.task_id]['end_time'] = time.time()
                    self.active_tasks[task.task_id]['results'] = results

                # Record task history
                self.task_history.append({
                    'task_id': task.task_id,
                    'interaction_type': interaction_type.value,
                    'duration': time.time() - start_time,
                    'results': results
                })

            return results
    
    async def _execute_cooperation_task(self, task: TaskDefinition, model_ids: List[str],
                                        execution_plan: Dict[str, Any]) -> Dict[str, Any]:
        """Run a cooperative task by executing each assigned subtask in turn.

        Subtasks are awaited one after another, in assignment order, which
        matches the 'sequential' coordination strategy of cooperative plans.
        ``model_ids`` is not consulted; ``execution_plan['subtask_assignments']``
        (model_id -> list of subtask names) decides who runs what.

        Returns:
            Dict with ``success``, ``subtask_results`` keyed
            ``"<model_id>_<subtask>"``, ``performance_metrics`` and
            ``resource_consumption``.
        """
        logger.info(f"Executing cooperative task: {task.task_id}")

        assignments = execution_plan.get('subtask_assignments', {})
        outcomes = {}
        for owner, owned_subtasks in assignments.items():
            for subtask_name in owned_subtasks:
                result_key = f"{owner}_{subtask_name}"
                outcomes[result_key] = await self._execute_model_subtask(owner, subtask_name, task)

        metrics = self._calculate_cooperation_metrics(outcomes)
        consumption = self._calculate_resource_consumption(outcomes)
        return {
            'success': True,
            'subtask_results': outcomes,
            'performance_metrics': metrics,
            'resource_consumption': consumption
        }
    
    async def _execute_competition_task(self, task: TaskDefinition, model_ids: List[str],
                                        execution_plan: Dict[str, Any]) -> Dict[str, Any]:
        """Run every model on ``task`` concurrently and pick a winner.

        Each model's attempt is launched as its own asyncio task before any
        is awaited. A failing attempt is logged and recorded as
        ``{'success': False, 'error': ...}`` rather than aborting the
        competition.

        Returns:
            Dict with ``success``, ``competition_results`` (model_id ->
            result), ``winner``, ``performance_metrics`` and
            ``resource_consumption``.
        """
        logger.info(f"Executing competitive task: {task.task_id}")

        # Launch all competitors up front so they run concurrently.
        launched = [
            (model_id, asyncio.create_task(self._execute_competitive_task(model_id, task)))
            for model_id in model_ids
        ]

        # Collect outcomes in launch order.
        outcomes = {}
        for model_id, pending in launched:
            try:
                outcomes[model_id] = await pending
            except Exception as e:
                logger.error(f"Competitive task {model_id} failed: {e}")
                outcomes[model_id] = {'success': False, 'error': str(e)}

        winner = self._evaluate_competition_winner(outcomes)
        metrics = self._calculate_competition_metrics(outcomes)
        consumption = self._calculate_resource_consumption(outcomes)

        return {
            'success': True,
            'competition_results': outcomes,
            'winner': winner,
            'performance_metrics': metrics,
            'resource_consumption': consumption
        }
    
    async def _execute_neutral_task(self, task: TaskDefinition, model_ids: List[str],
                                  execution_plan: Dict[str, Any]) -> Dict[str, Any]:
        """Run *task* independently on every model, concurrently.

        Each model executes the task via ``_execute_simple_task``. All runs are
        started together. If any run raises, every other run is still allowed to
        finish (so its ``finally`` cleanup executes), and then the first
        exception in *model_ids* order is re-raised.

        Args:
            task: Task definition executed by every model.
            model_ids: Participating model ids; results are keyed by these ids.
            execution_plan: Accepted for signature parity with the other
                ``_execute_*_task`` methods; not read here.

        Returns:
            Dict with ``success`` (always True), per-model ``parallel_results``,
            and aggregated ``performance_metrics`` and ``resource_consumption``.
        """
        logger.info(f"Executing neutral task: {task.task_id}")

        participants = list(model_ids)
        # The runs are independent, so actually execute them in parallel
        # instead of awaiting them one after another.
        outcomes = await asyncio.gather(
            *(self._execute_simple_task(model_id, task) for model_id in participants),
            return_exceptions=True,
        )

        parallel_results: Dict[str, Any] = {}
        for model_id, outcome in zip(participants, outcomes):
            if isinstance(outcome, BaseException):
                raise outcome
            parallel_results[model_id] = outcome

        return {
            'success': True,
            'parallel_results': parallel_results,
            'performance_metrics': self._calculate_parallel_metrics(parallel_results),
            'resource_consumption': self._calculate_resource_consumption(parallel_results)
        }
    
    async def _execute_model_subtask(self, model_id: str, subtask: str, task: TaskDefinition) -> Dict[str, Any]:
        """Simulate one subtask of a cooperative task on *model_id*'s GPU.

        Reserves ``compute = 0.5 * complexity`` and ``memory = 0.3 * complexity``
        on the model's GPU for the duration of the simulated run, and always
        releases them afterwards.

        Args:
            model_id: Key into ``self.model_profiles``.
            subtask: Subtask label; also used as the key of ``capability_scores``.
            task: Parent task; only ``task.complexity`` is read.

        Returns:
            ``{'success': False, 'error': ...}`` if the reservation is refused.
            Otherwise a success dict with the elapsed ``execution_time``, the
            ``gpu_id``, a ``capability_scores`` entry for *subtask* in ``(0, 1]``
            and the reserved ``resource_consumption``.
        """
        model_profile = self.model_profiles[model_id]

        # Request resources
        required_resources = {'compute': task.complexity * 0.5, 'memory': task.complexity * 0.3}
        if not self.resource_manager.request_resources(model_profile.gpu_id, required_resources):
            return {'success': False, 'error': 'Insufficient GPU resources'}

        try:
            # perf_counter is the clock meant for measuring durations; unlike
            # time.time() it is not affected by wall-clock adjustments.
            start_time = time.perf_counter()
            await asyncio.sleep(0.1 * task.complexity)  # Simulate execution time
            execution_time = time.perf_counter() - start_time

            # An (almost) instantaneous run, e.g. complexity 0 or a coarse clock,
            # would otherwise divide by zero; treat it as the best possible score.
            capability_score = min(1.0, 1.0 / execution_time) if execution_time > 0 else 1.0

            return {
                'success': True,
                'result': f"Model {model_id} completed subtask {subtask}",
                'execution_time': execution_time,
                'gpu_id': model_profile.gpu_id,
                'capability_scores': {
                    subtask: capability_score
                },
                # Reported so that _calculate_resource_consumption can aggregate it.
                'resource_consumption': dict(required_resources),
            }

        finally:
            self.resource_manager.release_resources(model_profile.gpu_id, required_resources)
    
    async def _execute_competitive_task(self, model_id: str, task: TaskDefinition) -> Dict[str, Any]:
        """Simulate one competitor's run of *task* on *model_id*'s GPU.

        First asks for ``compute = 0.8 * complexity`` / ``memory = 0.6 * complexity``
        with ``priority=0.8``. If that is refused, it retries once, without the
        ``priority`` argument, with 30% of that demand. Whatever was granted is
        released when the run ends.

        Args:
            model_id: Key into ``self.model_profiles``.
            task: Task to run; ``task.complexity`` sizes the demand and run time.

        Returns:
            ``{'success': False, 'error': ...}`` when both requests are refused.
            Otherwise a success dict with ``execution_time``, ``gpu_id``,
            ``competition_score`` and the granted ``resource_consumption``.
        """
        model_profile = self.model_profiles[model_id]

        # Competitive mode requires more resources
        required_resources = {'compute': task.complexity * 0.8, 'memory': task.complexity * 0.6}
        if not self.resource_manager.request_resources(model_profile.gpu_id, required_resources, priority=0.8):
            # Degrade the resource demand and try once more
            required_resources = {k: v * 0.3 for k, v in required_resources.items()}
            if not self.resource_manager.request_resources(model_profile.gpu_id, required_resources):
                return {'success': False, 'error': 'GPU资源竞争失败'}

        try:
            # perf_counter is the clock meant for measuring durations.
            start_time = time.perf_counter()
            await asyncio.sleep(0.1 * task.complexity)  # Simulate execution time
            execution_time = time.perf_counter() - start_time

            # Compute the competition score
            competition_score = self._calculate_competition_score(model_id, execution_time, task)

            return {
                'success': True,
                'result': f"模型 {model_id} 竞争完成",
                'execution_time': execution_time,
                'gpu_id': model_profile.gpu_id,
                'competition_score': competition_score,
                # Reported so that _calculate_resource_consumption can aggregate it.
                'resource_consumption': dict(required_resources),
            }

        finally:
            self.resource_manager.release_resources(model_profile.gpu_id, required_resources)
    
    async def _execute_simple_task(self, model_id: str, task: TaskDefinition) -> Dict[str, Any]:
        """Simulate a single independent run of *task* on *model_id*'s GPU.

        Reserves ``compute = 0.4 * complexity`` and ``memory = 0.2 * complexity``
        for the duration of the simulated run, and always releases them.

        Args:
            model_id: Key into ``self.model_profiles``.
            task: Task to run; ``task.complexity`` sizes the demand and run time.

        Returns:
            ``{'success': False, 'error': ...}`` if the reservation is refused.
            Otherwise a success dict with ``execution_time``, ``gpu_id`` and the
            reserved ``resource_consumption``.
        """
        model_profile = self.model_profiles[model_id]

        required_resources = {'compute': task.complexity * 0.4, 'memory': task.complexity * 0.2}
        if not self.resource_manager.request_resources(model_profile.gpu_id, required_resources):
            return {'success': False, 'error': 'Insufficient GPU resources'}

        try:
            # perf_counter is the clock meant for measuring durations.
            start_time = time.perf_counter()
            await asyncio.sleep(0.05 * task.complexity)  # Simulate execution time
            execution_time = time.perf_counter() - start_time

            return {
                'success': True,
                'result': f"模型 {model_id} 完成任务",
                'execution_time': execution_time,
                'gpu_id': model_profile.gpu_id,
                # Reported so that _calculate_resource_consumption can aggregate it.
                'resource_consumption': dict(required_resources),
            }

        finally:
            self.resource_manager.release_resources(model_profile.gpu_id, required_resources)
    
    def _calculate_competition_score(self, model_id: str, execution_time: float, task: TaskDefinition) -> float:
        """Score a competitor as the mean of its capability level and its speed.

        Args:
            model_id: Key into ``self.model_profiles``.
            execution_time: Measured run time in seconds.
            task: The task that was run (not used by the formula).

        Returns:
            ``(capability_score + time_efficiency) / 2``.
            ``capability_score`` is the mean of the profile's capability values,
            or 0.0 when the profile has none. ``time_efficiency`` is
            ``1 - execution_time / 10``, floored at 0.
        """
        model_profile = self.model_profiles[model_id]

        capabilities = model_profile.capabilities
        # An empty capability map would otherwise raise ZeroDivisionError.
        capability_score = (
            sum(capabilities.values()) / len(capabilities) if capabilities else 0.0
        )
        # Runs lasting 10 s or longer earn no speed credit.
        time_efficiency = max(0, 1.0 - execution_time / 10.0)

        return (capability_score + time_efficiency) / 2.0
    
    def _evaluate_competition_winner(self, competition_results: Dict[str, Any]) -> str:
        """评估竞争获胜者"""
        valid_results = {
            model_id: result for model_id, result in competition_results.items()
            if result.get('success', False)
        }
        
        if not valid_results:
            return "none"
        
        winner = max(
            valid_results.items(),
            key=lambda x: x[1].get('competition_score', 0.0)
        )[0]
        
        return winner
    
    def _calculate_final_reward(self, results: Dict[str, Any], task: TaskDefinition) -> float:
        """Turn a task outcome into a scalar reward.

        The base reward is ``task.reward_structure['base_reward']`` (1.0 if
        absent). A successful outcome earns the base plus 10% of the sum of its
        ``performance_metrics`` values. A failed outcome earns only 10% of the
        base.
        """
        base_reward = task.reward_structure.get('base_reward', 1.0)

        if not results.get('success', False):
            # Failure penalty: keep a tenth of the base reward.
            return base_reward * 0.1

        metric_total = sum(results.get('performance_metrics', {}).values())
        return base_reward + metric_total * 0.1
    
    async def _update_model_capabilities(self, model_ids: List[str], results: Dict[str, Any]):
        """更新模型能力"""
        for model_id in model_ids:
            if model_id not in self.model_profiles:
                continue
            
            profile = self.model_profiles[model_id]
            
            # 基于结果更新能力
            capability_scores = results.get('capability_scores', {})
            if capability_scores:
                for capability, score in capability_scores.items():
                    if capability in profile.capabilities:
                        profile.capabilities[capability] = (
                            profile.capabilities[capability] * 0.7 + score * 0.3
                        )
    
    def _calculate_cooperation_metrics(self, subtask_results: Dict[str, Any]) -> Dict[str, float]:
        """计算合作指标"""
        successful_tasks = sum(1 for result in subtask_results.values() if result.get('success', False))
        total_tasks = len(subtask_results)
        
        return {
            'success_rate': successful_tasks / total_tasks if total_tasks > 0 else 0.0,
            'coordination_efficiency': successful_tasks / max(total_tasks, 1)
        }
    
    def _calculate_competition_metrics(self, competition_results: Dict[str, Any]) -> Dict[str, float]:
        """计算竞争指标"""
        valid_results = [result for result in competition_results.values() if result.get('success', False)]
        
        if not valid_results:
            return {'success_rate': 0.0, 'competition_intensity': 0.0}
        
        scores = [result.get('competition_score', 0.0) for result in valid_results]
        mean_score = sum(scores) / len(scores)
        variance = sum((s - mean_score) ** 2 for s in scores) / len(scores)
        
        return {
            'success_rate': len(valid_results) / len(competition_results),
            'competition_intensity': variance ** 0.5
        }
    
    def _calculate_parallel_metrics(self, parallel_results: Dict[str, Any]) -> Dict[str, float]:
        """计算并行指标"""
        successful_tasks = sum(1 for result in parallel_results.values() if result.get('success', False))
        total_tasks = len(parallel_results)
        
        return {
            'success_rate': successful_tasks / total_tasks if total_tasks > 0 else 0.0,
            'parallel_efficiency': successful_tasks / max(total_tasks, 1)
        }
    
    def _calculate_resource_consumption(self, results: Dict[str, Any]) -> Dict[str, float]:
        """计算资源消耗"""
        total_consumption = {'compute': 0.0, 'memory': 0.0}
        
        for result in results.values():
            if isinstance(result, dict):
                consumption = result.get('resource_consumption', {})
                for resource_type, amount in consumption.items():
                    if resource_type in total_consumption:
                        total_consumption[resource_type] += amount
        
        return total_consumption
    
    async def health_check_all(self) -> Dict[str, bool]:
        """Probe every registered model with a short generation request.

        For each profile, a ``VLLMClient`` is opened for ``profile.url`` and
        asked to generate ``"health check"``. The outcome is stored both in the
        returned mapping and in ``profile.is_healthy``. Any exception marks the
        model unhealthy and is logged as a warning. Models are probed one at a
        time.
        """
        statuses: Dict[str, bool] = {}

        for model_id, profile in self.model_profiles.items():
            try:
                # Simplified probe: a single tiny generation request.
                async with VLLMClient(profile.url, self.model_name) as client:
                    await client.generate("health check")
                    statuses[model_id] = profile.is_healthy = True
            except Exception as exc:
                logger.warning(f"模型 {model_id} 健康检查失败: {exc}")
                statuses[model_id] = profile.is_healthy = False

        return statuses
    
    def get_system_statistics(self) -> Dict[str, Any]:
        """Collect a snapshot of scheduler, resource, slot and model statistics.

        ``task_statistics`` is ``self.stats`` itself, not a copy.
        ``gpu_distribution`` has one ``gpu_<i>`` entry for every ``i`` in
        ``range(self.num_gpus)``, counting the models whose ``gpu_id`` equals
        ``i``.
        """
        resource_stats = self.resource_manager.get_resource_stats()
        slot_stats = self.slot_manager.get_slot_stats()

        profiles = self.model_profiles.values()
        models_per_gpu = defaultdict(int)
        healthy_count = 0
        for profile in profiles:
            if profile.is_healthy:
                healthy_count += 1
            models_per_gpu[profile.gpu_id] += 1

        model_stats = {
            'total_models': len(self.model_profiles),
            'healthy_models': healthy_count,
            'gpu_distribution': {
                f'gpu_{index}': models_per_gpu.get(index, 0)
                for index in range(self.num_gpus)
            },
        }
        differentiation = self.capability_analyzer.detect_functional_differentiation(list(profiles))

        return {
            'task_statistics': self.stats,
            'resource_statistics': resource_stats,
            'slot_statistics': slot_stats,
            'model_statistics': model_stats,
            'functional_differentiation': differentiation,
        }
    
    async def shutdown(self):
        """Wait for every active task still marked ``'running'`` before closing.

        Iterates over a snapshot of ``self.active_tasks`` values. Each entry's
        status is checked at the moment it is reached, and its ``'executor'`` is
        awaited. An ``Exception`` from an executor is logged and does not stop
        the shutdown.
        """
        snapshot = list(self.active_tasks.values())

        for info in snapshot:
            if info['status'] != 'running':
                continue
            try:
                await info['executor']
            except Exception as err:
                logger.error(f"等待任务完成时出错: {err}")

        logger.info("Unified Scheduler已关闭")
    
    # Prefill-decoding scheduling methods
    def add_prefill_task(self, task: PrefillDecodingTask) -> bool:
        """Queue *task* under its priority if it passes the admission checks.

        Returns False, and logs a warning, when ``_can_schedule_prefill_task``
        rejects the task. Otherwise appends it to
        ``self.prefill_tasks[task.priority]`` and returns True.
        """
        admissible = bool(self._can_schedule_prefill_task(task))

        if admissible:
            self.prefill_tasks[task.priority].append(task)
            logger.info(f"Added prefill task {task.task_id} with priority {task.priority.value}")
        else:
            logger.warning(f"Cannot schedule prefill task {task.task_id}: resource constraints")

        return admissible
    
    def _can_schedule_prefill_task(self, task: PrefillDecodingTask) -> bool:
        """Decide whether *task* fits under the current resource constraints.

        A task is rejected when any of the following holds:
          * the active tasks' memory plus this task's ``memory_requirement``
            exceeds ``total_memory``;
          * ``max_concurrent_tasks`` tasks are already active;
          * ``input_length + max_output_length`` exceeds ``max_sequence_length``.
        """
        limits = self.resource_constraints

        reserved = sum(active.memory_requirement for active in self.active_prefill_tasks.values())
        if reserved + task.memory_requirement > limits.total_memory:
            return False

        if len(self.active_prefill_tasks) >= limits.max_concurrent_tasks:
            return False

        too_long = task.input_length + task.max_output_length > limits.max_sequence_length
        return not too_long
    
    def get_optimal_prefill_schedule(self) -> List[PrefillDecodingTask]:
        """
        Get optimal prefill-decoding schedule using TTFT/TBT optimization:
        S* = arg minS max(TTFT(S), TBT(S)) subject to Σi |KVi| ≤ Mtotal

        Drains every priority queue and orders the tasks by ``deadline``
        (earliest first). The tasks are passed to ``_optimize_prefill_schedule``.
        Every task it does not select is appended back to the queue for its
        priority, in deadline order.

        Returns:
            The selected tasks, as returned by ``_optimize_prefill_schedule``.
        """
        # Drain each priority queue in one step; repeated pop(0) on a list is
        # O(n^2). Queue order is preserved.
        all_tasks: List[PrefillDecodingTask] = []
        for priority in TaskPriority:
            pending = self.prefill_tasks[priority]
            all_tasks.extend(pending)
            pending.clear()

        # Sort by deadline (earliest first)
        all_tasks.sort(key=lambda t: t.deadline)

        # Apply optimization algorithm
        optimal_schedule = self._optimize_prefill_schedule(all_tasks)

        # Requeue unscheduled tasks by identity. If the task type compares by
        # value (dataclasses do by default), an `in` test would treat an
        # unscheduled task equal to a scheduled one as scheduled and silently
        # drop it. The set lookup is also O(1) instead of O(len(schedule)).
        scheduled_ids = {id(t) for t in optimal_schedule}
        for task in all_tasks:
            if id(task) not in scheduled_ids:
                self.prefill_tasks[task.priority].append(task)

        return optimal_schedule
    
    def _optimize_prefill_schedule(self, tasks: List[PrefillDecodingTask]) -> List[PrefillDecodingTask]:
        """
        Optimize prefill-decoding schedule using dynamic programming approach

        Implements: S* = arg minS max(TTFT(S), TBT(S)) subject to Σi |KVi| ≤ Mtotal

        Read literally, the objective is always minimised by the empty schedule:
        ``_calculate_ttft`` and ``_calculate_tbt`` return 0.0 for it and
        positive values otherwise. The search is therefore lexicographic:
          1. admit as many tasks as fit in ``resource_constraints.total_memory``;
          2. among schedules admitting that many tasks, keep the one with the
             smallest ``max(TTFT, TBT)``.

        The DP walks the tasks in deadline order. A state is
        ``(tasks chosen, memory used)`` and keeps the lowest-cost schedule that
        reaches it. The number of states is bounded by the number of distinct
        subset memory totals, so the cost grows with the memory budget's
        granularity.

        Args:
            tasks: Candidate tasks; sorted in place by ``deadline``.

        Returns:
            The selected tasks in deadline order, or an empty list if none fit.
        """
        if not tasks:
            return []

        # Sort tasks by deadline
        tasks.sort(key=lambda t: t.deadline)
        memory_budget = self.resource_constraints.total_memory

        # (count, memory used) -> (max(TTFT, TBT), schedule)
        states = {(0, 0): (0.0, [])}

        for task in tasks:
            # Option 1: skip this task -- every existing state carries over.
            next_states = dict(states)

            # Option 2: add this task to every state it still fits into.
            for (count, memory), (_cost, schedule) in states.items():
                new_memory = memory + task.memory_requirement
                if not new_memory <= memory_budget:
                    continue

                candidate = schedule + [task]
                cost = max(self._calculate_ttft(candidate), self._calculate_tbt(candidate))

                key = (count + 1, new_memory)
                if key not in next_states or cost < next_states[key][0]:
                    next_states[key] = (cost, candidate)

            states = next_states

        def rank(item):
            # Prefer more admitted tasks; break ties by lower latency.
            (count, _memory), (cost, _schedule) = item
            return (-count, cost)

        _, (_, best_schedule) = min(states.items(), key=rank)
        return best_schedule
    
    def _calculate_ttft(self, schedule: List[PrefillDecodingTask]) -> float:
        """Estimate Time-to-First-Token for *schedule* with a heuristic model.

        The estimate is built from two contributions:
          * each PREFILL or MIXED task adds
            ``0.01 * ln(1 + input_length) * (1 + memory_factor)``, where
            ``memory_factor`` is the schedule's total ``memory_requirement``
            divided by ``resource_constraints.total_memory``;
          * each DECODE or MIXED task adds a flat 0.005.

        MIXED tasks therefore contribute to both terms. Returns 0.0 for an empty
        schedule.
        """
        if not schedule:
            return 0.0

        # Group tasks by type
        prefill_tasks = [t for t in schedule if t.task_type == TaskType.PREFILL]
        decode_tasks = [t for t in schedule if t.task_type == TaskType.DECODE]
        mixed_tasks = [t for t in schedule if t.task_type == TaskType.MIXED]

        ttft = 0.0

        # Prefill phase contributes to TTFT
        prompt_tasks = prefill_tasks + mixed_tasks
        if prompt_tasks:
            # Memory pressure is a property of the whole schedule; compute it
            # once instead of re-summing it for every task (which was O(n^2)).
            memory_factor = sum(t.memory_requirement for t in schedule) / self.resource_constraints.total_memory
            for task in prompt_tasks:
                # TTFT increases with input length and memory pressure
                input_factor = math.log(1 + task.input_length)
                ttft += input_factor * (1 + memory_factor) * 0.01  # Base latency

        # Decode phase also contributes to TTFT
        for task in decode_tasks + mixed_tasks:
            ttft += 0.005  # Base decode latency

        return ttft
    
    def _calculate_tbt(self, schedule: List[PrefillDecodingTask]) -> float:
        """Estimate Time-Between-Tokens for *schedule*.

        Starts from a 0.05 s base and multiplies it by ``(1 + load)`` and
        ``(1 + pressure)``, where:
          * ``load`` is ``len(schedule) / max_concurrent_tasks``;
          * ``pressure`` is the summed ``memory_requirement`` divided by
            ``total_memory``.

        An empty schedule yields 0.0.
        """
        if not schedule:
            return 0.0

        limits = self.resource_constraints
        load = len(schedule) / limits.max_concurrent_tasks
        pressure = sum(t.memory_requirement for t in schedule) / limits.total_memory

        base_seconds = 0.05
        return base_seconds * (1 + load) * (1 + pressure)
    
    async def execute_prefill_schedule(self, schedule: List[PrefillDecodingTask]) -> Dict[str, Any]:
        """
        Execute the given prefill-decoding schedule (simulated).

        Each task is registered in ``self.active_prefill_tasks`` while it runs.
        It is "executed" with a short ``asyncio.sleep`` and recorded in
        ``task_results``. The active-task entries for the whole schedule are
        removed even if execution is cancelled or fails part-way, so stale
        entries cannot keep blocking ``_can_schedule_prefill_task``.

        Args:
            schedule: List of tasks to execute, in order.

        Returns:
            Execution results and metrics:
              * ``tasks_executed`` and ``total_time``;
              * the modelled ``ttft_achieved`` and ``tbt_achieved``;
              * ``memory_efficiency`` (scheduled memory / total memory);
              * ``throughput`` (estimated tokens per second);
              * per-task ``task_results``.
        """
        # perf_counter is the clock meant for measuring durations.
        start_time = time.perf_counter()
        results = {
            "tasks_executed": 0,
            "total_time": 0.0,
            "ttft_achieved": 0.0,
            "tbt_achieved": 0.0,
            "memory_efficiency": 0.0,
            "throughput": 0.0,
            "task_results": []
        }
        executed: List[PrefillDecodingTask] = []

        try:
            for task in schedule:
                # Mark the task active before it runs, so memory accounting
                # reflects it while it executes.
                self.active_prefill_tasks[task.task_id] = task

                task_start = time.perf_counter()
                await asyncio.sleep(0.01)  # Simulate processing time
                task_time = time.perf_counter() - task_start

                # Record task result
                results["task_results"].append({
                    "task_id": task.task_id,
                    "execution_time": task_time,
                    "memory_used": task.memory_requirement,
                    "success": True
                })
                results["tasks_executed"] += 1
                executed.append(task)
        finally:
            # Release bookkeeping for every task of this schedule, including
            # tasks that never ran because of an error or cancellation.
            for task in schedule:
                self.active_prefill_tasks.pop(task.task_id, None)
            for task in executed:
                self.completed_prefill_tasks.append(task)

        # Calculate final metrics
        results["total_time"] = time.perf_counter() - start_time

        # Calculate TTFT and TBT
        results["ttft_achieved"] = self._calculate_ttft(schedule)
        results["tbt_achieved"] = self._calculate_tbt(schedule)

        # Calculate memory efficiency
        total_memory_used = sum(t.memory_requirement for t in schedule)
        results["memory_efficiency"] = total_memory_used / self.resource_constraints.total_memory

        # Calculate throughput
        total_tokens = sum(t.estimated_tokens for t in schedule)
        results["throughput"] = total_tokens / results["total_time"] if results["total_time"] > 0 else 0

        # Update performance history
        self.ttft_history.append(results["ttft_achieved"])
        self.tbt_history.append(results["tbt_achieved"])
        self.throughput_history.append(results["throughput"])

        return results
    
    def get_prefill_performance_metrics(self) -> Dict[str, Any]:
        """Report the current prefill/decode load and running performance averages.

        Memory figures cover only tasks in ``self.active_prefill_tasks``.
        Averages are taken over the recorded histories (0.0 for an empty
        history). Targets are copied from ``self.performance_metrics``.
        """
        def _mean(values) -> float:
            return sum(values) / len(values) if values else 0.0

        in_use = sum(t.memory_requirement for t in self.active_prefill_tasks.values())
        targets = self.performance_metrics

        return {
            "current_memory_usage": in_use,
            "memory_utilization": in_use / self.resource_constraints.total_memory,
            "active_tasks": len(self.active_prefill_tasks),
            "pending_tasks": sum(len(bucket) for bucket in self.prefill_tasks.values()),
            "avg_ttft": _mean(self.ttft_history),
            "avg_tbt": _mean(self.tbt_history),
            "avg_throughput": _mean(self.throughput_history),
            "ttft_target": targets.ttft_target,
            "tbt_target": targets.tbt_target,
            "throughput_target": targets.throughput_target,
            "performance_score": self._calculate_prefill_performance_score(),
        }
    
    def _calculate_prefill_performance_score(self) -> float:
        """Calculate overall prefill-decoding performance score"""
        if not self.ttft_history or not self.tbt_history:
            return 0.0
        
        avg_ttft = sum(self.ttft_history) / len(self.ttft_history)
        avg_tbt = sum(self.tbt_history) / len(self.tbt_history)
        avg_throughput = sum(self.throughput_history) / len(self.throughput_history) if self.throughput_history else 0.0
        
        # Score based on how close we are to targets
        ttft_score = max(0, 1 - (avg_ttft / self.performance_metrics.ttft_target))
        tbt_score = max(0, 1 - (avg_tbt / self.performance_metrics.tbt_target))
        throughput_score = min(1, avg_throughput / self.performance_metrics.throughput_target)
        
        # Weighted average
        return 0.4 * ttft_score + 0.4 * tbt_score + 0.2 * throughput_score
    
    def optimize_prefill_memory_allocation(self):
        """Reorder the pending prefill queues by memory efficiency.

        Efficiency is ``estimated_tokens / max(memory_requirement, 1)``. All
        priority queues are drained, the tasks are sorted by efficiency (highest
        first; ties keep their previous relative order), and each task is
        appended back to the queue for its own priority.
        """
        drained = []
        for level in TaskPriority:
            bucket = self.prefill_tasks[level]
            drained.extend(bucket)
            bucket.clear()

        def tokens_per_memory(task):
            return task.estimated_tokens / max(task.memory_requirement, 1)

        for task in sorted(drained, key=tokens_per_memory, reverse=True):
            self.prefill_tasks[task.priority].append(task)
    
    def clear_completed_prefill_tasks(self):
        """Drop the record of finished prefill tasks and log that it happened."""
        finished = self.completed_prefill_tasks
        finished.clear()
        logger.info("Cleared completed prefill tasks")


# Factory functions
def create_unified_scheduler(base_port: int = 8001, num_gpus: int = 8, 
                           model_name: str = "qwen-2") -> UnifiedScheduler:
    """Build a ``UnifiedScheduler``.

    The three arguments are forwarded positionally to the constructor as
    ``(base_port, num_gpus, model_name)``.
    """
    scheduler = UnifiedScheduler(base_port, num_gpus, model_name)
    return scheduler


def create_cooperative_scheduler(base_port: int = 8001, num_gpus: int = 8,
                                 model_name: Optional[str] = None) -> UnifiedScheduler:
    """Create a scheduler pre-populated with cooperation-oriented models.

    Registers one ``ModelRole.COLLABORATOR`` model per index in
    ``range(num_gpus)``. Each model is named ``cooperative_model_<i>`` and has
    the capability profile reasoning=0.7, cooperation=0.8, efficiency=0.6.

    Args:
        base_port: Passed to ``create_unified_scheduler``.
        num_gpus: Passed to ``create_unified_scheduler``; also the number of
            models registered.
        model_name: Optional model name for ``create_unified_scheduler``. When
            omitted, that function's default is used.
    """
    if model_name is None:
        scheduler = create_unified_scheduler(base_port, num_gpus)
    else:
        scheduler = create_unified_scheduler(base_port, num_gpus, model_name)

    # Register the cooperation-oriented models
    for i in range(num_gpus):
        scheduler.register_model(
            f"cooperative_model_{i}",
            i,
            ModelRole.COLLABORATOR,
            {'reasoning': 0.7, 'cooperation': 0.8, 'efficiency': 0.6}
        )

    return scheduler


def create_competitive_scheduler(base_port: int = 8001, num_gpus: int = 8,
                                 model_name: Optional[str] = None) -> UnifiedScheduler:
    """Create a scheduler pre-populated with competition-oriented models.

    Registers one ``ModelRole.COMPETITOR`` model per index in
    ``range(num_gpus)``. Each model is named ``competitive_model_<i>`` and has
    the capability profile reasoning=0.8, competition=0.9, efficiency=0.7.

    Args:
        base_port: Passed to ``create_unified_scheduler``.
        num_gpus: Passed to ``create_unified_scheduler``; also the number of
            models registered.
        model_name: Optional model name for ``create_unified_scheduler``. When
            omitted, that function's default is used.
    """
    if model_name is None:
        scheduler = create_unified_scheduler(base_port, num_gpus)
    else:
        scheduler = create_unified_scheduler(base_port, num_gpus, model_name)

    # Register the competition-oriented models
    for i in range(num_gpus):
        scheduler.register_model(
            f"competitive_model_{i}",
            i,
            ModelRole.COMPETITOR,
            {'reasoning': 0.8, 'competition': 0.9, 'efficiency': 0.7}
        )

    return scheduler


def create_task_definition(task_id: str, task_type: str, complexity: float = 0.5,
                          required_capabilities: Optional[List[str]] = None,
                          priority: TaskPriority = TaskPriority.MEDIUM) -> TaskDefinition:
    """Build a ``TaskDefinition`` whose reward structure has a base reward of 1.0.

    When *required_capabilities* is omitted, a fresh
    ``['reasoning', 'efficiency']`` list is used for each call.
    """
    capabilities = (
        ['reasoning', 'efficiency'] if required_capabilities is None else required_capabilities
    )

    return TaskDefinition(
        task_id=task_id,
        task_type=task_type,
        complexity=complexity,
        required_capabilities=capabilities,
        priority=priority,
        reward_structure={'base_reward': 1.0},
    )
